""" Run counterfactuals.

Note: also runs the robustness check with a fortnightly time period (t = fortnight).

"""
if __name__ == '__main__':
    import pandas as pd
    import copy
    import sys
    sys.path.append('./')

    from src.models_new import simulation, counterfactuals
    from src.run_scripts import utils
    try:
        from src.interpolation.splines import UCGrid
    except ImportError:
        # Fall back to the package-relative path when `src` cannot be imported.
        # Only ImportError is caught so unrelated failures (and Ctrl-C) still surface.
        from ..interpolation.splines import UCGrid
    import numpy as np
    import scipy

    # %% READ IN THE INPUTS -------------------------------------------------------------
    # Inputs at the default time period (used below for the monthly runs).
    data, delta, rho, c, weights = utils.read_in_data(use_myopic=True)

    # The same five inputs read at the fortnightly time period.
    inputs_fortnight = utils.read_in_data(time_period='fortnight', use_myopic=True)
    (data_fortnight, delta_fortnight, rho_fortnight, c_fortnight,
     weights_fortnight) = inputs_fortnight

    # Load, for every spec, the saved grid parameters, the UCGrid built from
    # them, and the saved search value array.
    search_grid_params_by_spec = {}
    search_grid_by_spec = {}
    search_value_by_spec = {}
    for spec in ('low', 'mid', 'high'):
        grid_params = np.load(f'./models/value_search/search_grid_{spec}_month.npy')
        search_grid_params_by_spec[spec] = grid_params

        # Rows 0..3 of the saved array are handed to UCGrid as one tuple each.
        grid_rows = [tuple(grid_params[row, :]) for row in range(4)]
        search_grid_by_spec[spec] = UCGrid(*grid_rows)

        # Value array stored alongside the grid for this spec.
        search_value_by_spec[spec] = np.load(
            f'./models/value_search/search_value_{spec}_month.npy')

    #%% READ IN PARAMS ------------------------------------------------------------------
    # `import scipy` alone does not guarantee that the `stats` subpackage is
    # loaded, so import it explicitly before `scipy.stats.norm` is used below.
    import scipy.stats

    # Parameter file: the first column is the parameter name (index), the
    # remaining single column holds the values. `read_csv(squeeze=True)` was
    # removed in pandas 2.0, so squeeze the one-column frame explicitly.
    params = pd.read_csv(
        './models/smm/params_smm_with_diff_new.csv',
        index_col=[0]).squeeze('columns').to_dict()

    # Pre-compute some important quantities
    params['p_4'] = 1.0 - params['p_3'] - params['p_2']
    params['a_0'] = np.array([1.0, 1.0, 1.0])
    params['a_1'] = np.array([params['a_1_low'], params['a_1_mid'], params['a_1_high']])
    # Mass of N(mu_0, sigma_0) between 0 and 2.15 (2.15 is also the 'mri_max'
    # value used in the simulation config).
    params['denom'] = (
        scipy.stats.norm.cdf(
            2.15,
            loc=params['mu_0'],
            scale=params['sigma_0']
        )
        - scipy.stats.norm.cdf(
            0,
            loc=params['mu_0'],
            scale=params['sigma_0']
        )
    )

    # Fortnightly parameter set: a copy of the parameters above with d_0, d_1,
    # m_2 and every m_{0,1}_{low,mid,high} halved.
    # NOTE(review): presumably these are per-period rates -- confirm.
    params_fortnight = copy.deepcopy(params)
    halved_names = ['d_0', 'd_1', 'm_2']
    for i in [0, 1]:
        for j in ['low', 'mid', 'high']:
            halved_names.append(f'm_{i}_{j}')
    for name in halved_names:
        params_fortnight[name] = params_fortnight[name] / 2

    # %% DO THE SIMULATION --------------------------------------------------------------
    # Settings shared by the `simulation.run_steps` calls further down.
    args_config = dict(
        params_fixed={},
        weights=weights,
        mri_max=2.15,
        verbose=False,
        verbose_output=True,
        value_zero=False,
        entry_prob_by_tau_by_ym=None,
        tau_multiplier=1,
        x_names=list(params),
        well_target=True,
    )

    # Merge the shared settings with each data set; keys present in the data
    # dict take precedence over the shared settings.
    # NOTE(review): the fortnightly arguments inherit the monthly `weights`
    # unless `data_fortnight` carries its own 'weights' key, and
    # `weights_fortnight` is not passed here -- confirm this is intended.
    args_by_names = {**args_config, **data}
    args_by_names_fortnight = {**args_config, **data_fortnight}

    # Monthly arguments with the 'well_target' switch turned off.
    args_by_names_rig_target = copy.deepcopy(args_by_names)
    args_by_names_rig_target.update(well_target=False)

    config_demand_smoothing = dict(
        # 'sim_length': 1000,
        mri_grid=5,
        g_grid=2,
        n_grid=2,
    )
    options = dict(threads_per_worker=8, n_workers=1)
    seeds = np.arange(1)

    # Collects the output of every simulation, keyed by simulation name.
    results_by_sim = {}

    #%% DO SIMULATION -------------------------------------------------------------------

    # Monthly
    results_by_sim['benchmark'] = simulation.run_steps(
        x=list(params.values()),
        args_by_names=args_by_names
    )

    # Monthly, rig target
    results_by_sim['benchmark_rig_target'] = simulation.run_steps(
        x=list(params.values()),
        args_by_names=args_by_names_rig_target
    )

    # Fortnightly: simulate with the fortnightly parameter vector so it is
    # consistent with the fortnightly data and with the welfare/entry-cost
    # steps, which evaluate these runs using `params_fortnight`.
    results_by_sim['benchmark_fortnight'] = simulation.run_steps(
        x=list(params_fortnight.values()),
        args_by_names=args_by_names_fortnight
    )

    #%% DO NO SORTING COUNTERFACTUAL ----------------------------------------------------
    # Same parameters as the benchmark, but with both gamma terms set to zero.
    params_no_sorting = copy.deepcopy(params)
    params_no_sorting['gamma'] = 0.0
    params_no_sorting['gamma_negative'] = 0.0

    # Month: 'value_zero' is switched on and the entry probabilities are taken
    # from the corresponding benchmark run.
    args_by_names_no_sorting = copy.deepcopy(args_by_names)
    args_by_names_no_sorting['value_zero'] = True
    args_by_names_no_sorting['entry_prob_by_tau_by_ym'] = \
        results_by_sim['benchmark']['entry_prob_by_tau_by_ym']

    results_by_sim['no_sorting'] = simulation.run_steps(
        x=list(params_no_sorting.values()),
        args_by_names=args_by_names_no_sorting
    )

    # Month, rig target: reuse the same argument dict with 'well_target' off
    # and the entry probabilities of the rig-target benchmark.
    args_by_names_no_sorting['well_target'] = False
    args_by_names_no_sorting['entry_prob_by_tau_by_ym'] = \
        results_by_sim['benchmark_rig_target']['entry_prob_by_tau_by_ym']
    results_by_sim['no_sorting_rig_target'] = simulation.run_steps(
        x=list(params_no_sorting.values()),
        args_by_names=args_by_names_no_sorting
    )

    # Fortnight
    params_no_sorting_fortnight = copy.deepcopy(params_fortnight)
    params_no_sorting_fortnight['gamma'] = 0.0
    params_no_sorting_fortnight['gamma_negative'] = 0.0

    args_by_names_no_sorting_fortnight = copy.deepcopy(args_by_names_fortnight)
    args_by_names_no_sorting_fortnight['value_zero'] = True
    args_by_names_no_sorting_fortnight['entry_prob_by_tau_by_ym'] = \
        results_by_sim['benchmark_fortnight']['entry_prob_by_tau_by_ym']

    # Use the fortnightly no-sorting parameters built just above (previously
    # the monthly `params_no_sorting` vector was passed here and
    # `params_no_sorting_fortnight` was never used).
    results_by_sim['no_sorting_fortnight'] = simulation.run_steps(
        x=list(params_no_sorting_fortnight.values()),
        args_by_names=args_by_names_no_sorting_fortnight
    )

    #%% DO INTERMEDIARY -----------------------------------------------------------------
    # Entry counts are derived from the benchmark run's entry probabilities.
    benchmark_entry_prob = results_by_sim['benchmark']['entry_prob_by_tau_by_ym']
    n_enter_by_ym_tau = counterfactuals.get_n_enter(
        entry_prob=benchmark_entry_prob,
        data_g=data['state_data']['g'],
        params=params,
    )

    results_by_sim['intermediary'] = counterfactuals.do_simulation_intermediary(
        data['state_data'],
        data['n_rigs'],
        n_enter_by_ym_tau,
        params,
    )

    #%% DO DEMAND SMOOTHING -------------------------------------------------------------
    # The state data is deep-copied so the call receives its own copy.
    smoothing_output = counterfactuals.do_simulation_smoothing(
        params,
        state_data=copy.deepcopy(data['state_data']),
        n_rigs=data['n_rigs'],
        search_grid_by_spec=search_grid_by_spec,
        search_value_by_spec=search_value_by_spec,
        mri_max=args_config['mri_max'],
        config=config_demand_smoothing,
        seeds=seeds,
        options=options,
    )

    # The call returns four items; the first is stored with the other
    # simulation results, the rest are kept as script-level variables.
    (
        results_by_sim['demand_smoothing'],
        data_smoothing,
        search_value_by_spec_final,
        first_stage_by_spec,
    ) = smoothing_output

    #%% COMPUTE WELFARE -----------------------------------------------------------------
    import importlib
    importlib.reload(simulation)

    # Welfare is computed with the same call for every simulation; only the
    # parameter set and the state data differ by simulation.
    total_value_by_sim = {}
    for sim, sim_results in results_by_sim.items():
        print(sim)
        if sim in ('benchmark_fortnight', 'no_sorting_fortnight'):
            # Fortnight uses diff. data
            welfare_params = params_fortnight
            welfare_state_data = data_fortnight['state_data']
        elif sim == 'demand_smoothing':
            # Use flat natural gas price
            welfare_params = params
            welfare_state_data = data_smoothing
        else:
            # Use empirical gas price process
            welfare_params = params
            welfare_state_data = data['state_data']

        total_value_by_sim[sim] = simulation.get_welfare_from_simulation(
            sim_results['state_detail_by_ym'], welfare_params, mri_max=2.15,
            n_rigs=data['n_rigs'], state_data=welfare_state_data)

    #%% COMPUTE DECOMPOSITION -----------------------------------------------------------
    # Set which is the baseline for the counterfactuals
    baseline_by_sim = {
        'benchmark': 'no_sorting',
        'benchmark_rig_target': 'no_sorting_rig_target',
        'benchmark_fortnight': 'no_sorting_fortnight',
        'no_sorting': 'no_sorting',
        'no_sorting_rig_target': 'no_sorting_rig_target',
        'no_sorting_fortnight': 'no_sorting_fortnight',
        'intermediary': 'benchmark',
        'demand_smoothing': 'benchmark'
    }

    # Result frames start from the gas-price column of each data set. Take
    # explicit copies: new columns are assigned to these frames below, and
    # assigning onto a column selection of another frame triggers pandas'
    # SettingWithCopyWarning.
    df_all = data['state_data'][['g']].copy()
    df_all_fortnight = data_fortnight['state_data'][['g']].copy()

    n_rigs_total = data['n_rigs']['low'] + data['n_rigs']['mid'] + data['n_rigs']['high']

    # 'opex' is 0.032 times the total rig count; the fortnightly frame gets
    # half of that value.
    df_all['opex'] = 0.032 * n_rigs_total
    df_all_fortnight['opex'] = 0.032 * n_rigs_total / 2

    moments_by_sim = dict()
    df_shares_by_sim = dict()
    entry_cost_by_sim = dict()
    for sim in baseline_by_sim:
        print(sim)

        # Fortnightly simulations are written to their own frame.
        is_fortnight = sim in ('benchmark_fortnight', 'no_sorting_fortnight')
        df_out = df_all_fortnight if is_fortnight else df_all

        # Get total entry costs
        if sim == 'intermediary':
            # Entry numbers are those of the benchmark run (processed earlier
            # in this loop); the cost is the number entering times params['c'].
            benchmark_n_enter = df_shares_by_sim['benchmark']['number_enter']
            total_entry_cost = benchmark_n_enter * params['c']
            df_out[f'{sim}_entry_cost'] = copy.deepcopy(total_entry_cost)
            df_out[f'{sim}_n_entry'] = copy.deepcopy(benchmark_n_enter)
        else:
            if sim == 'demand_smoothing':
                # Mean of the gas price column is passed instead of the series.
                entry_data_g = data['state_data']['g'].mean()
                entry_params = params
            elif is_fortnight:
                # NOTE(review): the fortnightly runs pass the *mean* of g,
                # whereas the monthly benchmark/no-sorting runs pass the full
                # series -- confirm this is intended and not copied from the
                # demand-smoothing case.
                entry_data_g = data_fortnight['state_data']['g'].mean()
                entry_params = params_fortnight
            else:
                entry_data_g = data['state_data']['g']
                entry_params = params

            df_shares, entry_cost, n_enter = utils.get_entry_cost(
                data_g=entry_data_g,
                params=entry_params,
                shares_by_state=results_by_sim[sim]['shares_by_state']
            )
            df_shares_by_sim[sim] = copy.deepcopy(df_shares)
            df_out[f'{sim}_entry_cost'] = copy.deepcopy(entry_cost)
            # Store the number of entrants (previously the entry cost was
            # written into this column by mistake).
            df_out[f'{sim}_n_entry'] = copy.deepcopy(n_enter)

        # Get the total welfare series: one column per spec plus their sum
        total_by_ym = 0
        for spec in ['low', 'mid', 'high']:
            total_by_ym += total_value_by_sim[sim][spec]
            df_out[f'{sim}_total_value_{spec}'] = total_value_by_sim[sim][spec].values
        df_out[f'{sim}_total_value'] = total_by_ym.values

        # Get the decomposition: rig-count-weighted average utilization
        moments_by_sim[sim] = pd.DataFrame(results_by_sim[sim]['moments_by_state'])
        df_out[f'{sim}_av_utilization'] = (
            data['n_rigs']['low'] * moments_by_sim[sim]['utilization_low'].values
            + data['n_rigs']['mid'] * moments_by_sim[sim]['utilization_mid'].values
            + data['n_rigs']['high'] * moments_by_sim[sim]['utilization_high'].values
        ) / n_rigs_total

        # Per-spec utilization columns are only written for the
        # non-fortnightly simulations.
        if not is_fortnight:
            for spec in ['low', 'mid', 'high']:
                df_out[f'{sim}_av_utilization_{spec}'] = \
                    moments_by_sim[sim][f'utilization_{spec}'].values

    #%% GET TOTAL AGG UTILIZATION -------------------------------------------------------
    # Scale each simulation's total value by the ratio of its baseline's
    # average utilization to its own average utilization.
    for sim, baseline in baseline_by_sim.items():
        if sim in ('benchmark_fortnight', 'no_sorting_fortnight'):
            frame = df_all_fortnight
        else:
            frame = df_all
        utilization_ratio = (
            frame[f'{baseline}_av_utilization'] / frame[f'{sim}_av_utilization']
        )
        frame[f'{sim}_same_util'] = frame[f'{sim}_total_value'] * utilization_ratio

    #%% SAVE ----------------------------------------------------------------------------
    # Write the monthly frame first, then the fortnightly one.
    outputs = (
        (df_all, './models/counterfactuals/counterfactual_results.csv'),
        (df_all_fortnight, './models/counterfactuals/counterfactual_results_fortnight.csv'),
    )
    for frame, out_path in outputs:
        frame.to_csv(out_path)
